"""
Based on the file logistic.cpp at https://github.com/stanford-futuredata/wmsketch/blob/master/src/logistic.cpp

This file implements the same functionality as logistic.cpp above (with some functions omitted), and also has
an additional function which zeroes out all but the top K weights for some parameter K - this is used later for
seeing the test error when using only the top K weights.
"""

import numpy as np

# Here the loss function is l(x) = log(1 + e^(-x)). Thus
# the derivative is 1/(1 + e^(-x)) * (-e^(-x)).
def logistic_gradient(x):
    """Return the derivative of the logistic loss l(x) = log(1 + e^(-x)).

    The derivative is -e^(-x) / (1 + e^(-x)), which equals -1 / (1 + e^x).
    It is evaluated as -exp(-log(1 + e^x)) via ``np.logaddexp`` so that no
    intermediate value overflows: the naive form computes inf / inf = nan
    for large negative x, whereas the true limit there is -1.

    x may be a scalar or a NumPy array; the result lies in [-1, 0].
    """
    # logaddexp(0, x) == log(e^0 + e^x) == log(1 + e^x), computed stably.
    return -np.exp(-np.logaddexp(0, x))

class OnlineLogisticRegression:
    """Online logistic regression over sparse examples with L2 regularization.

    Examples are dictionaries mapping feature indices to feature values, and
    labels are +1 or -1. Weights are stored lazily scaled: the actual weight
    of a feature is ``self.scale * self.weights[index]``, so the L2 shrinkage
    applied at every step is a single multiplication of ``self.scale``
    instead of a pass over every stored weight.
    """

    def __init__(self, lr_init, l2_reg, no_bias):
        """Create an untrained model.

        lr_init: initial learning rate of the decaying schedule.
        l2_reg: L2 regularization strength.
        no_bias: if true, the bias is never updated and stays at 0.
        """
        # Sparse map from feature index to the *unscaled* weight.
        self.weights = dict()

        # no_bias is true if bias is always 0
        self.no_bias = no_bias
        self.bias = 0.0

        self.lr_init = lr_init

        self.l2_reg = l2_reg

        # Amount that self.weights should be scaled by to get
        # the actual weights. self.bias is already scaled
        # appropriately.
        self.scale = 1.0

        # Number of updates so far; drives the learning rate schedule.
        self.t = 0.0

    def get_weight(self, index):
        """Return the actual (scaled) weight of a feature, 0.0 if unseen."""
        stored = self.weights.get(index)
        if stored is None:
            return 0.0
        return self.scale * stored

    def get_bias(self, index=None):
        """Return the bias term.

        ``index`` is ignored; it is accepted only so that existing callers
        which pass an argument keep working.
        """
        return self.bias

    def _margin(self, example):
        """Return w^T x + b for a sparse example (index -> value dict)."""
        # get_weight already applies self.scale, so no extra scaling here.
        dot = sum(self.get_weight(index) * value
                  for index, value in example.items())
        return dot + self.bias

    def predict(self, example):
        """Classify a sparse example; returns +1 if w^T x + b >= 0 else -1."""
        return 1 if self._margin(example) >= 0 else -1

    def increment_weight(self, index, inc):
        """Add ``inc`` to the stored (unscaled) weight at ``index``."""
        if index in self.weights:
            self.weights[index] += inc
        else:
            self.weights[index] = inc

    def update(self, example, label):
        """Take one SGD step on a single example.

        example: dictionary mapping feature indices to feature values.
        label: +1 or -1 (note that in logistic.cpp it is 0 or 1).
        """
        # Use learning rate schedule from logistic.cpp
        lr = self.lr_init / (1.0 + self.lr_init * self.l2_reg * self.t)

        # The margin must be computed before the scale changes, since it
        # has to reflect the weights prior to this step.
        margin = self._margin(example)

        # L2 shrinkage: multiply every actual weight by (1 - lr * l2_reg).
        new_scale = self.scale * (1 - lr * self.l2_reg)
        if new_scale == 0:
            # Every actual weight is now exactly zero. Drop the stored
            # values and restart the scale at 1 rather than dividing by
            # zero in the increments below.
            self.weights = dict()
            self.scale = 1.0
        else:
            self.scale = new_scale

        g = logistic_gradient(label * margin)
        for index, value in example.items():
            # Divide by the scale so the *actual* weight moves by
            # -lr * label * g * value.
            self.increment_weight(index, -1 * lr * label * g * value / self.scale)
        if not self.no_bias:
            self.bias -= lr * label * g
        self.t += 1

    def sparsify_top_k(self, k):
        """Zero out all but the k weights of largest absolute value.

        The model is modified in place, so callers that still need the full
        model should sparsify a copy. Because every stored weight shares the
        same scale factor, ranking the stored values by absolute value ranks
        the actual weights identically.

        Returns 0 if the weights were truncated. If k is negative (-1 is the
        conventional "keep everything" value) or exceeds the number of
        stored weights, nothing is changed, the number of stored weights is
        printed, and -1 is returned.
        """
        if k < 0 or k > len(self.weights):
            print("Number of nonzero features: ", len(self.weights))
            return -1

        # Stable sort in decreasing order of absolute value.
        ranked = sorted(self.weights.items(), key=lambda item: -abs(item[1]))

        # Keep only the k largest entries.
        self.weights = dict(ranked[:k])

        return 0